import argparse

from allennlp.data.token_indexers import SingleIdTokenIndexer, PretrainedTransformerIndexer, ELMoTokenCharactersIndexer, \
    TokenCharactersIndexer
from allennlp.common.tqdm import Tqdm

from tools.tokenizers import WordTokenizer, CharacterTokenizer, PretrainedTransformerTokenizer, WhitespaceTokenizer
from tools.tokenizers import StopwordFilter

from tools.color import Color
from victim_model.embedding.embedder_factory import EmbedderFactory


class Config:
    """Experiment configuration for training and attacking a victim model.

    The attributes assigned before ``self.parse_arg()`` in ``__init__`` are
    the command-line tunable options: each one is exposed as ``--<name>``
    with its in-code value as the default.  Everything assigned afterwards
    (paths, encoder settings, the ``token_config`` factory) is derived from
    those options and is not configurable from the command line.

    Note that constructing a ``Config`` parses ``sys.argv`` and changes the
    default minimum refresh interval of AllenNLP's ``Tqdm`` wrapper.
    """

    # Accepted spellings for boolean command-line values (compared lowercased).
    _TRUE_STRINGS = ('true', 't', 'yes', 'y', '1')
    # The empty string is kept falsy because ``bool('')`` was False under the
    # previous ``type=bool`` handling.
    _FALSE_STRINGS = ('false', 'f', 'no', 'n', '0', '')

    def __init__(self):

        # train
        self.device = '0'
        self.epoch = 10
        self.batch_size = 50
        self.quiet = False
        self.d_ckpt = 'ckpt2'

        # victim model
        self.max_vocab_size = 8000

        # task
        self.dataset = 'agnews'
        self.encoder = '1-lstm'
        self.token = 'word'

        # attack
        self.attacker = 'genetic'
        self.adv_id = 'none'

        # Sentinel attribute: ``__repr__`` stops printing when it reaches it,
        # and ``parse_arg`` is called right after it, so only the attributes
        # defined up to this point become command-line options.
        self.copyright = ''
        self.parse_arg()

        self.attack_ratio_or_num = {'pwws': 0.3, 'genetic': 0.3, 'universal': 0.3, 'pmi': 0.3, 'random': 0.3}
        self.attack_dataset_size = 1000
        self.transfer_dataset_size = 500

        # dataset
        self.d_dataset = '/dataset'  # !!!change to your dataset directory
        self.d_raw = f'{self.d_dataset}/{self.dataset}/raw'
        self.d_split = f'{self.d_dataset}/{self.dataset}/split'
        self.p_split = {i: f'{self.d_split}/{i}' for i in ['train', 'val', 'test', ]}

        self.model_id = f'{self.dataset}_{self.encoder}_{self.token}'
        # 'none' means "no explicit id given": derive one from the run settings.
        if self.adv_id == 'none':
            self.adv_id = f'{self.dataset}_{self.attacker}_{self.encoder}_{self.token}'
        self.p_adv = {i: f'adv/{i}/{self.adv_id}' for i in
                      ['detail', 'adv_examples', 'semi_adv_examples', 'universal', 'pre_universal']}

        self.p_log = {i: f'{i}.txt' for i in ['train', 'attack', 'transfer']}

        self.p_synonym = f'{self.d_dataset}/synonym/ibp-nbrs.json'

        # Per-encoder settings, keyed by the value of ``self.encoder``.
        self.encoder_config = {
            "1-lstm": {
                "num_layers": 1,
                "bidirectional": False
            },
            "1-bi-lstm": {
                "num_layers": 1,
                "bidirectional": True
            },
            "2-lstm": {
                "num_layers": 2,
                "bidirectional": False
            },
            "2-bi-lstm": {
                "num_layers": 2,
                "bidirectional": True
            },
            "4-lstm": {
                "num_layers": 4,
                "bidirectional": False
            },
            "4-bi-lstm": {
                "num_layers": 4,
                "bidirectional": True
            },

            "1-cnn": {
                "ngram_filter_sizes": (3,)
            },

            "2-cnn": {
                "ngram_filter_sizes": (7, 3)
            },
            "4-cnn": {
                "ngram_filter_sizes": (7, 7, 3, 3)
            },
            "6-cnn": {
                "ngram_filter_sizes": (7, 7, 3, 3, 3, 3)
            },

            "bert": {
                'pretrained_model': 'bert-base-uncased'
            },
            "roberta": {
                'pretrained_model': 'roberta-base'
            },
            "albert": {
                'pretrained_model': 'albert-base-v2'
            },
        }

        def token_config(token):
            """Build the tokenization/embedding settings for ``token``.

            Returns a dict with the keys ``tokenizer``, ``token_indexers``,
            ``embedder`` and ``max_len`` (a per-dataset mapping); every branch
            except the word-level one also provides ``filter``.  Fresh
            tokenizer/indexer objects are created on every call.

            Returns ``None`` when ``token`` matches none of the known types.
            """
            if token in ['word', 'glove', 'word2vec', 'fasttext']:
                return {
                    'tokenizer': WordTokenizer(),
                    'token_indexers': {
                        'tokens': SingleIdTokenIndexer(lowercase_tokens=True, token_min_padding_length=7)},
                    'embedder': EmbedderFactory(token, f'{self.d_dataset}/embedding').get_embedder,
                    'max_len': {'agnews': -1, 'imdb': 400, 'mr': -1}
                }
            if token in ['char']:
                return {
                    'tokenizer': WhitespaceTokenizer(),
                    'filter': StopwordFilter(),
                    'token_indexers': {
                        'token_characters': TokenCharactersIndexer(min_padding_length=6, token_min_padding_length=7)},
                    'embedder': EmbedderFactory(token, f'{self.d_dataset}/embedding').get_embedder,
                    'max_len': {'agnews': 70, 'imdb': 400, 'mr': 70}
                }
            if token == 'elmo':
                return {
                    'tokenizer': WhitespaceTokenizer(),
                    'filter': StopwordFilter(),
                    'token_indexers': {
                        'tokens': SingleIdTokenIndexer(lowercase_tokens=True, token_min_padding_length=7),
                        'token_characters': ELMoTokenCharactersIndexer(token_min_padding_length=7)},
                    'embedder': EmbedderFactory(token, f'{self.d_dataset}/embedding').get_embedder,
                    'max_len': {'agnews': 70, 'imdb': 500, 'mr': 70}
                }
            if token in ['bert', 'roberta', 'albert']:
                # Transformer tokens reuse the checkpoint name registered for
                # the encoder of the same name.
                model_name = self.encoder_config[token]['pretrained_model']
                return {
                    'tokenizer': PretrainedTransformerTokenizer(model_name=model_name),
                    'filter': StopwordFilter(),
                    'token_indexers': {'tokens': PretrainedTransformerIndexer(model_name=model_name)},
                    'embedder': EmbedderFactory(model_name, f'{self.d_dataset}/embedding').get_embedder,
                    'max_len': {'agnews': 150, 'imdb': 400, 'mr': 150},
                }

        # Exposed as a callable attribute so the settings are built on demand.
        self.token_config = token_config
        # NOTE(review): presumably throttles progress-bar refreshes; confirm
        # the unit of the interval against the AllenNLP Tqdm wrapper.
        Tqdm.set_default_mininterval(1000)

    @staticmethod
    def _parse_bool(text):
        """Convert a command-line string to ``bool``.

        ``argparse``'s ``type=bool`` is a trap: ``bool('False')`` is ``True``
        because any non-empty string is truthy.  This converter maps the
        usual spellings explicitly.

        Raises:
            argparse.ArgumentTypeError: if ``text`` is not a recognised
                boolean spelling (argparse turns this into a usage error).
        """
        lowered = text.strip().lower()
        if lowered in Config._TRUE_STRINGS:
            return True
        if lowered in Config._FALSE_STRINGS:
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {text!r}')

    def parse_arg(self):
        """Override the current instance attributes from the command line.

        Every attribute present on the instance when this is called becomes
        a ``--<name>`` option whose default is the current value and whose
        type is the type of that value.  Boolean attributes accept
        ``true/false``-style values (case-insensitive) and may also be given
        as a bare flag (``--quiet``), which means ``True``.

        Reads ``sys.argv`` via ``ArgumentParser.parse_args()``.
        """
        parser = argparse.ArgumentParser()
        for name, value in vars(self).items():
            if type(value) is bool:
                # ``type=bool`` would turn '--quiet False' into True.
                parser.add_argument(f'--{name}', type=self._parse_bool, nargs='?',
                                    const=True, default=value)
            else:
                parser.add_argument(f'--{name}', type=type(value), default=value)

        args = parser.parse_args()
        # Iterate over a snapshot of the names: attributes are being
        # reassigned while we walk them.
        for name in list(vars(self)):
            setattr(self, name, getattr(args, name))

    def __repr__(self):
        """Return a colorized listing of the command-line tunable options.

        Attributes are listed in definition order up to, but not including,
        the ``copyright`` sentinel, so derived paths and tables are omitted.
        """
        ret = Color.red('config:\n')
        for name, value in vars(self).items():
            if name == 'copyright':
                break
            ret += Color.blue(name) + "=" + Color.magenta(value) + '\t'
        return ret
